{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Chroma"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "이 노트북에서는 Chroma 벡터스토어를 시작하는 방법을 다룹니다.\n",
    "\n",
    "Chroma는 개발자의 생산성과 행복에 초점을 맞춘 AI 네이티브 오픈 소스 벡터 데이터베이스입니다. Chroma는 Apache 2.0에 따라 라이선스가 부여됩니다. \n",
    "\n",
    "\n",
    "**참고링크**\n",
    "\n",
    "- [Chroma LangChain 문서](https://python.langchain.com/v0.2/docs/integrations/vectorstores/chroma/)\n",
    "- [Chroma 공식문서](https://docs.trychroma.com/getting-started)\n",
    "- [LangChain 지원 VectorStore 리스트](https://python.langchain.com/v0.2/docs/integrations/vectorstores/)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# API 키를 환경변수로 관리하기 위한 설정 파일\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "# API 키 정보 로드\n",
    "load_dotenv()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LangSmith 추적을 설정합니다. https://smith.langchain.com\n",
    "# !pip install langchain-teddynote\n",
    "from langchain_teddynote import logging\n",
    "\n",
    "# 프로젝트 이름을 입력합니다.\n",
    "logging.langsmith(\"CH09-VectorStores\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "샘플 데이터셋을 로드합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.document_loaders import TextLoader\n",
    "from langchain_openai.embeddings import OpenAIEmbeddings\n",
    "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
    "from langchain_chroma import Chroma\n",
    "\n",
    "\n",
    "# 텍스트 분할\n",
    "text_splitter = RecursiveCharacterTextSplitter(chunk_size=600, chunk_overlap=0)\n",
    "\n",
    "# 텍스트 파일을 load -> List[Document] 형태로 변환\n",
    "loader1 = TextLoader(\"data/nlp-keywords.txt\")\n",
    "loader2 = TextLoader(\"data/finance-keywords.txt\")\n",
    "\n",
    "# 문서 분할\n",
    "split_doc1 = loader1.load_and_split(text_splitter)\n",
    "split_doc2 = loader2.load_and_split(text_splitter)\n",
    "\n",
    "# 문서 개수 확인\n",
    "len(split_doc1), len(split_doc2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## VectorStore 생성"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 벡터 저장소 생성 (from_documents)\n",
    "\n",
    "`from_documents` 클래스 메서드는 문서 리스트로부터 벡터 저장소를 생성합니다. \n",
    "\n",
    "**매개변수**\n",
    "\n",
    "- `documents` (List[Document]): 벡터 저장소에 추가할 문서 리스트\n",
    "- `embedding` (Optional[Embeddings]): 임베딩 함수. 기본값은 None\n",
    "- `ids` (Optional[List[str]]): 문서 ID 리스트. 기본값은 None\n",
    "- `collection_name` (str): 생성할 컬렉션 이름.\n",
    "- `persist_directory` (Optional[str]): 컬렉션을 저장할 디렉토리. 기본값은 None\n",
    "- `client_settings` (Optional[chromadb.config.Settings]): Chroma 클라이언트 설정\n",
    "- `client` (Optional[chromadb.Client]): Chroma 클라이언트 인스턴스\n",
    "- `collection_metadata` (Optional[Dict]): 컬렉션 구성 정보. 기본값은 None\n",
    "\n",
    "**참고**\n",
    "\n",
    "- `persist_directory`가 지정되면 컬렉션이 해당 디렉토리에 저장됩니다. 지정되지 않으면 데이터는 메모리에 임시로 저장됩니다.\n",
    "- 이 메서드는 내부적으로 `from_texts` 메서드를 호출하여 벡터 저장소를 생성합니다.\n",
    "- 문서의 `page_content`는 텍스트로, `metadata`는 메타데이터로 사용됩니다.\n",
    "\n",
    "**반환값**\n",
    "\n",
    "- `Chroma`: 생성된 Chroma 벡터 저장소 인스턴스"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "생성시 `documents` 매개변수로 `Document` 리스트를 전달합니다. embedding 에 활용할 임베딩 모델을 지정하며, `namespace` 의 역할을 하는 `collection_name` 을 지정할 수 있습니다.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# DB 생성\n",
    "db = Chroma.from_documents(\n",
    "    documents=split_doc1, embedding=OpenAIEmbeddings(), collection_name=\"my_db\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`persist_directory` 지정시 disk 에 파일 형태로 저장합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 저장할 경로 지정\n",
    "DB_PATH = \"./chroma_db\"\n",
    "\n",
    "# 문서를 디스크에 저장합니다. 저장시 persist_directory에 저장할 경로를 지정합니다.\n",
    "persist_db = Chroma.from_documents(\n",
    "    split_doc1, OpenAIEmbeddings(), persist_directory=DB_PATH, collection_name=\"my_db\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "아래의 코드를 실행하여 `DB_PATH` 에 저장된 데이터를 로드합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 디스크에서 문서를 로드합니다.\n",
    "persist_db = Chroma(\n",
    "    persist_directory=DB_PATH,\n",
    "    embedding_function=OpenAIEmbeddings(),\n",
    "    collection_name=\"my_db\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "불러온 VectorStore 에서 저장된 데이터를 확인합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 저장된 데이터 확인\n",
    "persist_db.get()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "만약 `collection_name` 을 다르게 지정하면 저장된 데이터가 없기 때문에 아무런 결과도 얻지 못합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 디스크에서 문서를 로드합니다.\n",
    "persist_db2 = Chroma(\n",
    "    persist_directory=DB_PATH,\n",
    "    embedding_function=OpenAIEmbeddings(),\n",
    "    collection_name=\"my_db2\",\n",
    ")\n",
    "\n",
    "# 저장된 데이터 확인\n",
    "persist_db2.get()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 벡터 저장소 생성 (from_texts)\n",
    "\n",
    "`from_texts` 클래스 메서드는 텍스트 리스트로부터 벡터 저장소를 생성합니다.\n",
    "\n",
    "**매개변수**\n",
    "\n",
    "- `texts` (List[str]): 컬렉션에 추가할 텍스트 리스트\n",
    "- `embedding` (Optional[Embeddings]): 임베딩 함수. 기본값은 None\n",
    "- `metadatas` (Optional[List[dict]]): 메타데이터 리스트. 기본값은 None\n",
    "- `ids` (Optional[List[str]]): 문서 ID 리스트. 기본값은 None\n",
    "- `collection_name` (str): 생성할 컬렉션 이름. 기본값은 '_LANGCHAIN_DEFAULT_COLLECTION_NAME'\n",
    "- `persist_directory` (Optional[str]): 컬렉션을 저장할 디렉토리. 기본값은 None\n",
    "- `client_settings` (Optional[chromadb.config.Settings]): Chroma 클라이언트 설정\n",
    "- `client` (Optional[chromadb.Client]): Chroma 클라이언트 인스턴스\n",
    "- `collection_metadata` (Optional[Dict]): 컬렉션 구성 정보. 기본값은 None\n",
    "\n",
    "**참고**\n",
    "\n",
    "- `persist_directory`가 지정되면 컬렉션이 해당 디렉토리에 저장됩니다. 지정되지 않으면 데이터는 메모리에 임시로 저장됩니다.\n",
    "- `ids`가 제공되지 않으면 UUID를 사용하여 자동으로 생성됩니다.\n",
    "\n",
    "**반환값**\n",
    "\n",
    "- 생성된 벡터 저장소 인스턴스"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 문자열 리스트로 생성\n",
    "db2 = Chroma.from_texts(\n",
    "    [\"안녕하세요. 정말 반갑습니다.\", \"제 이름은 테디입니다.\"],\n",
    "    embedding=OpenAIEmbeddings(),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 데이터를 조회합니다.\n",
    "db2.get()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 유사도 검색\n",
    "\n",
    "`similarity_search` 메서드는 Chroma 데이터베이스에서 유사도 검색을 수행합니다. 이 메서드는 주어진 쿼리와 가장 유사한 문서들을 반환합니다.\n",
    "\n",
    "**매개변수**\n",
    "\n",
    "- `query` (str): 검색할 쿼리 텍스트\n",
    "- `k` (int, 선택적): 반환할 결과의 수. 기본값은 4입니다.\n",
    "- `filter` (Dict[str, str], 선택적): 메타데이터로 필터링. 기본값은 None입니다.\n",
    "\n",
    "**참고**\n",
    "\n",
    "- `k` 값을 조절하여 원하는 수의 결과를 얻을 수 있습니다.\n",
    "- `filter` 매개변수를 사용하여 특정 메타데이터 조건에 맞는 문서만 검색할 수 있습니다.\n",
    "- 이 메서드는 점수 정보 없이 문서만 반환합니다. 점수 정보도 필요한 경우 `similarity_search_with_score` 메서드를 직접 사용하세요.\n",
    "\n",
    "**반환값**\n",
    "\n",
    "- `List[Document]`: 쿼리 텍스트와 가장 유사한 문서들의 리스트"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "db.similarity_search(\"TF IDF 에 대하여 알려줘\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`k` 값에 검색 결과의 개수를 지정할 수 있습니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "db.similarity_search(\"TF IDF 에 대하여 알려줘\", k=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`filter` 에 `metadata` 정보를 활용하여 검색 결과를 필터링 할 수 있습니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# filter 사용\n",
    "db.similarity_search(\n",
    "    \"TF IDF 에 대하여 알려줘\", filter={\"source\": \"data/nlp-keywords.txt\"}, k=2\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "다음은 `filter` 에서 다른 `source` 를 사용하여 검색한 결과를 확인합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# filter 사용\n",
    "db.similarity_search(\n",
    "    \"TF IDF 에 대하여 알려줘\", filter={\"source\": \"data/finance-keywords.txt\"}, k=2\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 벡터 저장소에 문서 추가\n",
    "\n",
    "`add_documents` 메서드는 벡터 저장소에 문서를 추가하거나 업데이트합니다.\n",
    "\n",
    "**매개변수**\n",
    "\n",
    "- `documents` (List[Document]): 벡터 저장소에 추가할 문서 리스트\n",
    "- `**kwargs`: 추가 키워드 인자\n",
    "  - `ids`: 문서 ID 리스트 (제공 시 문서의 ID보다 우선함)\n",
    "\n",
    "**참고**\n",
    "\n",
    "- `add_texts` 메서드가 구현되어 있어야 합니다.\n",
    "- 문서의 `page_content`는 텍스트로, `metadata`는 메타데이터로 사용됩니다.\n",
    "- 문서에 ID가 있고 `kwargs`에 ID가 제공되지 않으면 문서의 ID가 사용됩니다.\n",
    "- `kwargs`의 ID와 문서 수가 일치하지 않으면 ValueError가 발생합니다.\n",
    "\n",
    "**반환값**\n",
    "\n",
    "- `List[str]`: 추가된 텍스트의 ID 리스트\n",
    "\n",
    "**예외**\n",
    "\n",
    "- `NotImplementedError`: `add_texts` 메서드가 구현되지 않은 경우 발생"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.documents import Document\n",
    "\n",
    "# page_content, metadata, id 지정\n",
    "db.add_documents(\n",
    "    [\n",
    "        Document(\n",
    "            page_content=\"안녕하세요! 이번엔 도큐먼트를 새로 추가해 볼께요\",\n",
    "            metadata={\"source\": \"mydata.txt\"},\n",
    "            id=\"1\",\n",
    "        )\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# id=1 로 문서 조회\n",
    "db.get(\"1\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`add_texts` 메서드는 텍스트를 임베딩하고 벡터 저장소에 추가합니다.\n",
    "\n",
    "**매개변수**\n",
    "\n",
    "- `texts` (Iterable[str]): 벡터 저장소에 추가할 텍스트 리스트\n",
    "- `metadatas` (Optional[List[dict]]): 메타데이터 리스트. 기본값은 None\n",
    "- `ids` (Optional[List[str]]): 문서 ID 리스트. 기본값은 None\n",
    "\n",
    "**참고**\n",
    "\n",
    "- `ids`가 제공되지 않으면 UUID를 사용하여 자동으로 생성됩니다.\n",
    "- 임베딩 함수가 설정되어 있으면 텍스트를 임베딩합니다.\n",
    "- 메타데이터가 제공된 경우:\n",
    "  - 메타데이터가 있는 텍스트와 없는 텍스트를 분리하여 처리합니다.\n",
    "  - 메타데이터가 없는 텍스트의 경우 빈 딕셔너리로 채웁니다.\n",
    "- 컬렉션에 upsert 작업을 수행하여 텍스트, 임베딩, 메타데이터를 추가합니다.\n",
    "\n",
    "**반환값**\n",
    "\n",
    "- `List[str]`: 추가된 텍스트의 ID 리스트\n",
    "\n",
    "**예외**\n",
    "\n",
    "- `ValueError`: 복잡한 메타데이터로 인한 오류 발생 시, 필터링 방법 안내 메시지와 함께 발생"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "기존의 아이디에 추가하는 경우 `upsert` 가 수행되며, 기존의 문서는 대체됩니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 신규 데이터를 추가합니다. 이때 기존의 id=1 의 데이터는 덮어쓰게 됩니다.\n",
    "db.add_texts(\n",
    "    [\"이전에 추가한 Document 를 덮어쓰겠습니다.\", \"덮어쓴 결과가 어떤가요?\"],\n",
    "    metadatas=[{\"source\": \"mydata.txt\"}, {\"source\": \"mydata.txt\"}],\n",
    "    ids=[\"1\", \"2\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# id=1 조회\n",
    "db.get([\"1\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 벡터 저장소에서 문서 삭제\n",
    "\n",
    "`delete` 메서드는 벡터 저장소에서 지정된 ID의 문서를 삭제합니다.\n",
    "\n",
    "**매개변수**\n",
    "\n",
    "- `ids` (Optional[List[str]]): 삭제할 문서의 ID 리스트. 기본값은 None\n",
    "\n",
    "**참고**\n",
    "\n",
    "- 이 메서드는 내부적으로 컬렉션의 `delete` 메서드를 호출합니다.\n",
    "- `ids`가 None이면 아무 작업도 수행하지 않습니다.\n",
    "\n",
    "**반환값**\n",
    "\n",
    "- None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# id 1 삭제\n",
    "db.delete(ids=[\"1\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 문서 조회\n",
    "db.get([\"1\", \"2\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# where 조건으로 metadata 조회\n",
    "db.get(where={\"source\": \"mydata.txt\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 초기화(reset_collection)\n",
    "\n",
    "`reset_collection` 메서드는 벡터 저장소의 컬렉션을 초기화합니다.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 컬렉션 초기화\n",
    "db.reset_collection()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 초기화 후 문서 조회\n",
    "db.get()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 벡터 저장소를 검색기(Retriever)로 변환\n",
    "\n",
    "`as_retriever` 메서드는 벡터 저장소를 기반으로 VectorStoreRetriever를 생성합니다.\n",
    "\n",
    "**매개변수**\n",
    "\n",
    "- `**kwargs`: 검색 함수에 전달할 키워드 인자\n",
    "  - `search_type` (Optional[str]): 검색 유형 (`\"similarity\"`, `\"mmr\"`, `\"similarity_score_threshold\"`)\n",
    "  - `search_kwargs` (Optional[Dict]): 검색 함수에 전달할 추가 인자\n",
    "    - `k`: 반환할 문서 수 (기본값: 4)\n",
    "    - `score_threshold`: 최소 유사도 임계값\n",
    "    - `fetch_k`: MMR 알고리즘에 전달할 문서 수 (기본값: 20)\n",
    "    - `lambda_mult`: MMR 결과의 다양성 조절 (0~1, 기본값: 0.5)\n",
    "    - `filter`: 문서 메타데이터 필터링\n",
    "\n",
    "**반환값**\n",
    "\n",
    "- `VectorStoreRetriever`: 벡터 저장소 기반 검색기 인스턴스"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`DB` 를 생성합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# DB 생성\n",
    "db = Chroma.from_documents(\n",
    "    documents=split_doc1 + split_doc2,\n",
    "    embedding=OpenAIEmbeddings(),\n",
    "    collection_name=\"nlp\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "기본 값으로 설정된 4개 문서를 유사도 검색을 수행하여 조회합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "retriever = db.as_retriever()\n",
    "retriever.invoke(\"Word2Vec 에 대하여 알려줘\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "다양성이 높은 더 많은 문서 검색\n",
    "\n",
    "- `k`: 반환할 문서 수 (기본값: 4)\n",
    "- `fetch_k`: MMR 알고리즘에 전달할 문서 수 (기본값: 20)\n",
    "- `lambda_mult`: MMR 결과의 다양성 조절 (0~1, 기본값: 0.5, 0: 유사도 점수만 고려, 1: 다양성만 고려)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "retriever = db.as_retriever(\n",
    "    search_type=\"mmr\", search_kwargs={\"k\": 6, \"lambda_mult\": 0.25, \"fetch_k\": 10}\n",
    ")\n",
    "retriever.invoke(\"Word2Vec 에 대하여 알려줘\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "MMR 알고리즘을 위해 더 많은 문서를 가져오되 상위 2개만 반환"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "retriever = db.as_retriever(search_type=\"mmr\", search_kwargs={\"k\": 2, \"fetch_k\": 10})\n",
    "retriever.invoke(\"Word2Vec 에 대하여 알려줘\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "특정 임계값 이상의 유사도를 가진 문서만 검색"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "retriever = db.as_retriever(\n",
    "    search_type=\"similarity_score_threshold\", search_kwargs={\"score_threshold\": 0.8}\n",
    ")\n",
    "\n",
    "retriever.invoke(\"Word2Vec 에 대하여 알려줘\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "가장 유사한 단일 문서만 검색"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "retriever = db.as_retriever(search_kwargs={\"k\": 1})\n",
    "\n",
    "retriever.invoke(\"Word2Vec 에 대하여 알려줘\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "특정 메타데이터 필터 적용"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "retriever = db.as_retriever(\n",
    "    search_kwargs={\"filter\": {\"source\": \"data/finance-keywords.txt\"}, \"k\": 2}\n",
    ")\n",
    "retriever.invoke(\"ESG 에 대하여 알려줘\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 멀티모달 검색\n",
    "\n",
    "Chroma는 멀티모달 컬렉션, 즉 여러 양식의 데이터를 포함하고 쿼리할 수 있는 컬렉션을 지원합니다."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 데이터 세트\n",
    "\n",
    "허깅페이스에서 호스팅되는 [coco object detection dataset](https://huggingface.co/datasets/detection-datasets/coco)의 작은 하위 집합을 사용합니다.\n",
    "\n",
    "데이터 세트의 모든 이미지 중 일부만 로컬로 다운로드하고 이를 사용하여 멀티모달 컬렉션을 생성합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from datasets import load_dataset\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# COCO 데이터셋 로드\n",
    "dataset = load_dataset(\n",
    "    path=\"detection-datasets/coco\", name=\"default\", split=\"train\", streaming=True\n",
    ")\n",
    "\n",
    "# 이미지 저장 폴더와 이미지 개수 설정\n",
    "IMAGE_FOLDER = \"tmp\"\n",
    "N_IMAGES = 20\n",
    "\n",
    "# 그래프 플로팅을 위한 설정\n",
    "plot_cols = 5\n",
    "plot_rows = N_IMAGES // plot_cols\n",
    "fig, axes = plt.subplots(plot_rows, plot_cols, figsize=(plot_rows * 2, plot_cols * 2))\n",
    "axes = axes.flatten()\n",
    "\n",
    "# 이미지를 폴더에 저장하고 그래프에 표시\n",
    "dataset_iter = iter(dataset)\n",
    "os.makedirs(IMAGE_FOLDER, exist_ok=True)\n",
    "for i in range(N_IMAGES):\n",
    "    # 데이터셋에서 이미지와 레이블 추출\n",
    "    data = next(dataset_iter)\n",
    "    image = data[\"image\"]\n",
    "    label = data[\"objects\"][\"category\"][0]  # 첫 번째 객체의 카테고리를 레이블로 사용\n",
    "\n",
    "    # 그래프에 이미지 표시 및 레이블 추가\n",
    "    axes[i].imshow(image)\n",
    "    axes[i].set_title(label, fontsize=8)\n",
    "    axes[i].axis(\"off\")\n",
    "\n",
    "    # 이미지 파일로 저장\n",
    "    image.save(f\"{IMAGE_FOLDER}/{i}.jpg\")\n",
    "\n",
    "# 그래프 레이아웃 조정 및 표시\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](./images/chroma-01.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Multimodal Embeddings\n",
    "\n",
    "Multimodal Embeddings 을 활용하여 이미지, 텍스트에 대한 Embedding 을 생성합니다.\n",
    "\n",
    "이번 튜토리얼에서는 OpenClipEmbeddingFunction 을 사용하여 이미지를 임베딩합니다.\n",
    "\n",
    "- [OpenCLIP](https://github.com/mlfoundations/open_clip/tree/main)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model 벤치마크\n",
    "\n",
    "| Model                              | Training data  | Resolution | # of samples seen | ImageNet zero-shot acc. |\n",
    "|------------------------------------|----------------|------------|-------------------|-------------------------|\n",
    "| ConvNext-Base                       | LAION-2B       | 256px      | 13B               | 71.5%                   |\n",
    "| ConvNext-Large                      | LAION-2B       | 320px      | 29B               | 76.9%                   |\n",
    "| ConvNext-XXLarge                    | LAION-2B       | 256px      | 34B               | 79.5%                   |\n",
    "| ViT-B/32                            | DataComp-1B    | 256px      | 34B               | 72.8%                   |\n",
    "| ViT-B/16                            | DataComp-1B    | 224px      | 13B               | 73.5%                   |\n",
    "| ViT-L/14                            | LAION-2B       | 224px      | 32B               | 75.3%                   |\n",
    "| ViT-H/14                            | LAION-2B       | 224px      | 32B               | 78.0%                   |\n",
    "| ViT-L/14                            | DataComp-1B    | 224px      | 13B               | 79.2%                   |\n",
    "| ViT-G/14                            | LAION-2B       | 224px      | 34B               | 80.1%                   |\n",
    "| ViT-L/14 ([Original CLIP](https://openai.com/research/clip)) | WIT            | 224px      | 13B               | 75.5%                   |\n",
    "| ViT-SO400M/14 ([SigLIP](https://github.com/mlfoundations/open_clip)) | WebLI | 224px | 45B | 82.0% |\n",
    "| ViT-SO400M-14-SigLIP-384 ([SigLIP](https://github.com/mlfoundations/open_clip)) | WebLI | 384px | 45B | 83.1% |\n",
    "| ViT-H/14-quickgelu ([DFN](https://www.deeplearning.ai/glossary/neural-networks/)) | DFN-5B | 224px | 39B | 83.4% |\n",
    "| ViT-H-14-378-quickgelu ([DFN](https://www.deeplearning.ai/glossary/neural-networks/)) | DFN-5B | 378px | 44B | 84.4% |\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "아래의 예시에서 `model_name` 과 `checkpoint` 를 설정하여 사용합니다.\n",
    "\n",
    "- `model_name`: OpenCLIP 모델명\n",
    "- `checkpoint`: OpenCLIP 모델의 `Training data` 에 해당하는 이름"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import open_clip\n",
    "import pandas as pd\n",
    "\n",
    "# 사용 가능한 모델/Checkpoint 를 출력\n",
    "pd.DataFrame(open_clip.list_pretrained(), columns=[\"model_name\", \"checkpoint\"]).head(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_experimental.open_clip import OpenCLIPEmbeddings\n",
    "\n",
    "# OpenCLIP 임베딩 함수 객체 생성\n",
    "image_embedding_function = OpenCLIPEmbeddings(\n",
    "    model_name=\"ViT-H-14-378-quickgelu\", checkpoint=\"dfn5b\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "이미지의 경로를 list 로 저장합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 이미지의 경로를 리스트로 저장\n",
    "image_uris = sorted(\n",
    "    [\n",
    "        os.path.join(\"tmp\", image_name)\n",
    "        for image_name in os.listdir(\"tmp\")\n",
    "        if image_name.endswith(\".jpg\")\n",
    "    ]\n",
    ")\n",
    "\n",
    "image_uris"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_teddynote.models import MultiModal\n",
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "# ChatOpenAI 모델 초기화\n",
    "llm = ChatOpenAI(model=\"gpt-4o-mini\")\n",
    "\n",
    "# MultiModal 모델 설정\n",
    "model = MultiModal(\n",
    "    model=llm,\n",
    "    system_prompt=\"Your mission is to describe the image in detail\",  # 시스템 프롬프트: 이미지를 상세히 설명하도록 지시\n",
    "    user_prompt=\"Description should be written in one sentence(less than 60 characters)\",  # 사용자 프롬프트: 60자 이내의 한 문장으로 설명 요청\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "image 에 대한 description 을 생성합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 이미지 설명 생성\n",
    "model.invoke(image_uris[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](./images/chroma-02.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 이미지 설명\n",
    "descriptions = dict()\n",
    "\n",
    "for image_uri in image_uris:\n",
    "    descriptions[image_uri] = model.invoke(image_uri, display_image=False)\n",
    "\n",
    "# 생성된 결과물 출력\n",
    "descriptions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# 원본 이미지, 처리된 이미지, 텍스트 설명을 저장할 리스트 초기화\n",
    "original_images = []\n",
    "images = []\n",
    "texts = []\n",
    "\n",
    "# 그래프 크기 설정 (20x10 인치)\n",
    "plt.figure(figsize=(20, 10))\n",
    "\n",
    "# 'tmp' 디렉토리에 저장된 이미지 파일들을 처리\n",
    "for i, image_uri in enumerate(image_uris):\n",
    "    # 이미지 파일 열기 및 RGB 모드로 변환\n",
    "    image = Image.open(image_uri).convert(\"RGB\")\n",
    "\n",
    "    # 4x5 그리드의 서브플롯 생성\n",
    "    plt.subplot(4, 5, i + 1)\n",
    "\n",
    "    # 이미지 표시\n",
    "    plt.imshow(image)\n",
    "\n",
    "    # 이미지 파일명과 설명을 제목으로 설정\n",
    "    plt.title(f\"{os.path.basename(image_uri)}\\n{descriptions[image_uri]}\", fontsize=8)\n",
    "\n",
    "    # x축과 y축의 눈금 제거\n",
    "    plt.xticks([])\n",
    "    plt.yticks([])\n",
    "\n",
    "    # 원본 이미지, 처리된 이미지, 텍스트 설명을 각 리스트에 추가\n",
    "    original_images.append(image)\n",
    "    images.append(image)\n",
    "    texts.append(descriptions[image_uri])\n",
    "\n",
    "# 서브플롯 간 간격 조정\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](./images/chroma-03.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "아래는 생성한 이미지 description 과 텍스트 간의 유사도를 계산합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# 이미지와 텍스트 임베딩\n",
    "# 이미지 URI를 사용하여 이미지 특징 추출\n",
    "img_features = image_embedding_function.embed_image(image_uris)\n",
    "# 텍스트 설명에 \"This is\" 접두사를 추가하고 텍스트 특징 추출\n",
    "text_features = image_embedding_function.embed_documents(\n",
    "    [\"This is \" + desc for desc in texts]\n",
    ")\n",
    "\n",
    "# 행렬 연산을 위해 리스트를 numpy 배열로 변환\n",
    "img_features_np = np.array(img_features)\n",
    "text_features_np = np.array(text_features)\n",
    "\n",
    "# 유사도 계산\n",
    "# 텍스트와 이미지 특징 간의 코사인 유사도를 계산\n",
    "similarity = np.matmul(text_features_np, img_features_np.T)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "텍스트 대 이미지 description 간 유사도를 구하고 시각화합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 유사도 행렬을 시각화하기 위한 플롯 생성\n",
    "count = len(descriptions)\n",
    "plt.figure(figsize=(20, 14))\n",
    "\n",
    "# 유사도 행렬을 히트맵으로 표시\n",
    "plt.imshow(similarity, vmin=0.1, vmax=0.3, cmap=\"coolwarm\")\n",
    "plt.colorbar()  # 컬러바 추가\n",
    "\n",
    "# y축에 텍스트 설명 표시\n",
    "plt.yticks(range(count), texts, fontsize=18)\n",
    "plt.xticks([])  # x축 눈금 제거\n",
    "\n",
    "# 원본 이미지를 x축 아래에 표시\n",
    "for i, image in enumerate(original_images):\n",
    "    plt.imshow(image, extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin=\"lower\")\n",
    "\n",
    "# 유사도 값을 히트맵 위에 텍스트로 표시\n",
    "for x in range(similarity.shape[1]):\n",
    "    for y in range(similarity.shape[0]):\n",
    "        plt.text(x, y, f\"{similarity[y, x]:.2f}\", ha=\"center\", va=\"center\", size=12)\n",
    "\n",
    "# 플롯 테두리 제거\n",
    "for side in [\"left\", \"top\", \"right\", \"bottom\"]:\n",
    "    plt.gca().spines[side].set_visible(False)\n",
    "\n",
    "# 플롯 범위 설정\n",
    "plt.xlim([-0.5, count - 0.5])\n",
    "plt.ylim([count + 0.5, -2])\n",
    "\n",
    "# 제목 추가\n",
    "plt.title(\"Cosine Similarity\", size=20)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](./images/chroma-04.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Vectorstore 생성 및 이미지 추가"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Vectorstore 를 생성하고 이미지를 추가합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# DB 생성\n",
    "image_db = Chroma(\n",
    "    collection_name=\"multimodal\",\n",
    "    embedding_function=image_embedding_function,\n",
    ")\n",
    "\n",
    "# 이미지 추가\n",
    "image_db.add_images(uris=image_uris)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "아래는 이미지 검색된 결과를 이미지로 출력하기 위한 helper class 입니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import base64\n",
    "import io\n",
    "from PIL import Image\n",
    "from IPython.display import HTML, display\n",
    "from langchain.schema import Document\n",
    "\n",
    "\n",
    "class ImageRetriever:\n",
    "    def __init__(self, retriever):\n",
    "        \"\"\"\n",
    "        이미지 검색기를 초기화합니다.\n",
    "\n",
    "        인자:\n",
    "        retriever: LangChain의 retriever 객체\n",
    "        \"\"\"\n",
    "        self.retriever = retriever\n",
    "\n",
    "    def invoke(self, query):\n",
    "        \"\"\"\n",
    "        쿼리를 사용하여 이미지를 검색하고 표시합니다.\n",
    "\n",
    "        인자:\n",
    "        query (str): 검색 쿼리\n",
    "        \"\"\"\n",
    "        docs = self.retriever.invoke(query)\n",
    "        if docs and isinstance(docs[0], Document):\n",
    "            self.plt_img_base64(docs[0].page_content)\n",
    "        else:\n",
    "            print(\"검색된 이미지가 없습니다.\")\n",
    "        return docs\n",
    "\n",
    "    @staticmethod\n",
    "    def resize_base64_image(base64_string, size=(224, 224)):\n",
    "        \"\"\"\n",
    "        Base64 문자열로 인코딩된 이미지의 크기를 조정합니다.\n",
    "\n",
    "        인자:\n",
    "        base64_string (str): 원본 이미지의 Base64 문자열.\n",
    "        size (tuple): (너비, 높이)로 표현된 원하는 이미지 크기.\n",
    "\n",
    "        반환:\n",
    "        str: 크기가 조정된 이미지의 Base64 문자열.\n",
    "        \"\"\"\n",
    "        img_data = base64.b64decode(base64_string)\n",
    "        img = Image.open(io.BytesIO(img_data))\n",
    "        resized_img = img.resize(size, Image.LANCZOS)\n",
    "        buffered = io.BytesIO()\n",
    "        resized_img.save(buffered, format=img.format)\n",
    "        return base64.b64encode(buffered.getvalue()).decode(\"utf-8\")\n",
    "\n",
    "    @staticmethod\n",
    "    def plt_img_base64(img_base64):\n",
    "        \"\"\"\n",
    "        Base64로 인코딩된 이미지를 표시합니다.\n",
    "\n",
    "        인자:\n",
    "        img_base64 (str): Base64로 인코딩된 이미지 문자열\n",
    "        \"\"\"\n",
    "        image_html = f'<img src=\"data:image/jpeg;base64,{img_base64}\" />'\n",
    "        display(HTML(image_html))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Image Retriever 생성\n",
    "retriever = image_db.as_retriever(search_kwargs={\"k\": 3})\n",
    "image_retriever = ImageRetriever(retriever)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 이미지 조회\n",
    "result = image_retriever.invoke(\"A Dog on the street\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](./images/chroma-05.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 이미지 조회\n",
    "result = image_retriever.invoke(\"Motorcycle with a man\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](./images/chroma-06.png)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py-test",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
